Seeing is believing

Using FlashTorch 🔦 to shine a light on what neural nets "see"


by Misa Ogura

Hello, I'm Misa 👋


  • Originally from Tokyo, now based in London
  • Cancer Cell Biologist turned Software Engineer
  • Currently at BBC R&D
  • Co-founder of Women Driven Development
  • Women in Data Science London Ambassador

Convolutional Neural Network (CNN)


Kernel & Convolution


Kernel: a small matrix used for blurring, sharpening, embossing, edge detection, etc.

Convolution: sliding the kernel over the image and replacing each pixel with the weighted sum of its neighbourhood, with weights given by the kernel
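That definition can be sketched in a few lines of NumPy. This is a naive, illustrative version of what `cv2.filter2D` computes (note that `filter2D` actually performs correlation, i.e. it does not flip the kernel, and real implementations handle borders more carefully):

```python
import numpy as np

def convolve2d(image, kernel):
    """Naive 2D convolution: each output pixel is the weighted sum of
    the input pixel's neighbourhood, weighted by the (flipped) kernel."""
    kh, kw = kernel.shape
    ph, pw = kh // 2, kw // 2
    padded = np.pad(image, ((ph, ph), (pw, pw)), mode='constant')
    flipped = kernel[::-1, ::-1]  # true convolution flips the kernel
    out = np.zeros_like(image, dtype=float)
    for y in range(image.shape[0]):
        for x in range(image.shape[1]):
            out[y, x] = np.sum(padded[y:y + kh, x:x + kw] * flipped)
    return out

# A vertical edge (dark left, bright right) responds strongly to Sobel x
img = np.array([[0, 0, 1, 1],
                [0, 0, 1, 1],
                [0, 0, 1, 1],
                [0, 0, 1, 1]], dtype=float)

sobel_x = np.array([[-1, 0, 1],
                    [-2, 0, 2],
                    [-1, 0, 1]], dtype=float)

response = convolve2d(img, sobel_x)
```

Pixels sitting on the edge get a large-magnitude response; pixels in flat regions get zero, which is exactly why Sobel kernels act as edge detectors.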

In [9]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

# `image` should be a grayscale image loaded as a NumPy array
fig = plt.figure(figsize=(18, 6))
ax = fig.add_subplot(1, 3, 1, xticks=[], yticks=[])
ax.imshow(image, cmap='gray')
ax.set_title('Original image')

sobel_x = np.array([[ -1, 0, 1], 
                    [ -2, 0, 2], 
                    [ -1, 0, 1]])

sobel_y = np.array([[ -1, -2, -1], 
                    [ 0, 0, 0], 
                    [ 1, 2, 1]])

kernels = {'Sobel x': sobel_x, 'Sobel y': sobel_y}

for i, (title, kernel) in enumerate(kernels.items()):
    filtered_img = cv2.filter2D(image, -1, kernel)
    
    ax = fig.add_subplot(1, 3, i+2, xticks=[], yticks=[])
    ax.imshow(filtered_img, cmap='gray')
    ax.set_title(title)

Typical CNN Architecture


CNN Visualisation Techniques


Saliency map


Activation maximisation


Demo 1

Introducing FlashTorch & how to visualise saliency maps


First things first - config & imports


$ pip install flashtorch
In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision.models as models

Load an image


In [2]:
from flashtorch.utils import load_image

image = load_image('../examples/images/great_grey_owl_01.jpg')

plt.imshow(image)
plt.title('Original image')
plt.axis('off');

Convert the PIL image to a torch tensor


In [3]:
from flashtorch.utils import apply_transforms

input_ = apply_transforms(image)

print(f'Before: {type(image)}')
print(f'After: {type(input_)}, {input_.shape}')
Before: <class 'PIL.Image.Image'>
After: <class 'torch.Tensor'>, torch.Size([1, 3, 224, 224])

Let's visualise the input first...


In [4]:
# plt.imshow(input_)
# plt.title('Input tensor')
# plt.axis('off');
RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.

Let's visualise the input (take two)


In [5]:
from flashtorch.utils import format_for_plotting

plt.imshow(format_for_plotting(input_))
plt.title('Input tensor')
plt.axis('off');
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Let's visualise the input (take THREE)


In [6]:
from flashtorch.utils import denormalize

plt.imshow(format_for_plotting(denormalize(input_)))
plt.title('Input tensor')
plt.axis('off');
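Why `denormalize` is needed: `apply_transforms` applies per-channel normalisation, which pushes pixel values outside the [0, 1] range that `imshow` expects, and `denormalize` inverts it. A NumPy sketch of the roundtrip, assuming the standard ImageNet channel statistics (the exact transform inside FlashTorch may differ slightly):

```python
import numpy as np

# Standard ImageNet channel statistics (an assumption about what
# apply_transforms uses under the hood)
MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

def normalize(img):
    """img: float array of shape (3, H, W) with values in [0, 1]."""
    return (img - MEAN) / STD

def denormalize(tensor):
    """Inverse of normalize: map channel-normalised values back to [0, 1]."""
    return tensor * STD + MEAN

rng = np.random.default_rng(0)
img = rng.random((3, 224, 224))
roundtrip = denormalize(normalize(img))
```

Normalised values can be negative or greater than 1; after `denormalize` the image is back in displayable range.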

Load a pre-trained model & create a backprop object


In [7]:
from flashtorch.saliency import Backprop

model = models.alexnet(pretrained=True)

backprop = Backprop(model)
Signature:

    backprop.calculate_gradients(input_, target_class=None, take_max=False)
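Conceptually, `calculate_gradients` computes the gradient of the target class score with respect to each input pixel. For a toy linear "model" that gradient has a closed form (the weight row of the target class), which makes the idea easy to check. This is a hedged NumPy sketch of the concept, not FlashTorch's implementation:

```python
import numpy as np

# Toy linear "model": scores = W @ x, so d(scores[c])/dx = W[c]
rng = np.random.default_rng(42)
n_classes, n_pixels = 5, 16
W = rng.normal(size=(n_classes, n_pixels))
x = rng.normal(size=n_pixels)

target_class = 2
scores = W @ x

# Saliency: gradient of the target class score w.r.t. each input element.
# For a linear model this is exactly the target's weight row; FlashTorch
# obtains the same quantity for a CNN via autograd.
saliency = W[target_class]

# Sanity check with a finite difference on the first input element
eps = 1e-6
x_bumped = x.copy()
x_bumped[0] += eps
grad_fd = ((W @ x_bumped)[target_class] - scores[target_class]) / eps
```

For a deep network the gradient is no longer a fixed weight row, but the interpretation is the same: pixels with large gradients are the ones the prediction is most sensitive to.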

Retrieve the class index for the object in the input


In [8]:
from flashtorch.utils import ImageNetIndex 

imagenet = ImageNetIndex()
target_class = imagenet['great grey owl']

print(target_class)
24

It also does fuzzy matching (to some extent...)


In [9]:
# imagenet['dog']
ValueError: Multiple potential matches found: maltese dog, old english sheepdog, shetland sheepdog, greater swiss mountain dog, bernese mountain dog, french bulldog, eskimo dog, african hunting dog, dogsled, hotdog

Finally! It's time to calculate the gradients of the target class score w.r.t. each pixel of the input image


In [10]:
gradients = backprop.calculate_gradients(input_, target_class)

print(type(gradients), gradients.shape)
<class 'torch.Tensor'> torch.Size([3, 224, 224])

You can also take the maximum of the gradients across colour channels


In [11]:
max_gradients = backprop.calculate_gradients(input_, target_class, take_max=True)

print(type(max_gradients), max_gradients.shape)
<class 'torch.Tensor'> torch.Size([1, 224, 224])

Let's inspect gradients by plotting them out


In [12]:
from flashtorch.utils import visualize

visualize(input_, gradients, max_gradients)

It shows, roughly, that pixels in the area where the animal is present had the strongest positive effect on the prediction.

But it's quite noisy...

Guided backprop to the rescue!


In [13]:
guided_gradients = backprop.calculate_gradients(input_, target_class, guided=True)
max_guided_gradients = backprop.calculate_gradients(input_, target_class, take_max=True, guided=True)
In [14]:
visualize(input_, guided_gradients, max_guided_gradients)

Now that's much less noisy!

We can clearly see that pixels around the head and eyes had the strongest positive effect on the prediction.
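The reason it is less noisy: during the backward pass, guided backprop zeroes the gradient at each ReLU wherever either the forward activation or the incoming gradient is negative, so only signals that positively contribute to the target score survive. A minimal NumPy sketch of that backward rule (illustrative only, not FlashTorch's actual hook code):

```python
import numpy as np

def relu_backward_vanilla(grad_out, forward_input):
    # Plain ReLU backward: pass gradient where the forward input was positive
    return grad_out * (forward_input > 0)

def relu_backward_guided(grad_out, forward_input):
    # Guided backprop additionally masks out negative incoming gradients
    return grad_out * (forward_input > 0) * (grad_out > 0)

forward_input = np.array([-1.0, 2.0, 3.0, 0.5])
grad_out = np.array([0.7, -0.3, 0.9, 0.2])

vanilla = relu_backward_vanilla(grad_out, forward_input)
guided = relu_backward_guided(grad_out, forward_input)
```

The extra `grad_out > 0` mask is the whole difference, applied at every ReLU in the network.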

What about a jay?


In [16]:
# Recompute the gradients for a different target class
target_class = imagenet['jay']

guided_gradients = backprop.calculate_gradients(input_, target_class, guided=True)
max_guided_gradients = backprop.calculate_gradients(input_, target_class, take_max=True, guided=True)

visualize(input_, guided_gradients, max_guided_gradients)

Or an oystercatcher...


In [18]:
target_class = imagenet['oystercatcher']

guided_gradients = backprop.calculate_gradients(input_, target_class, guided=True)
max_guided_gradients = backprop.calculate_gradients(input_, target_class, take_max=True, guided=True)

visualize(input_, guided_gradients, max_guided_gradients)

Demo 2

Using FlashTorch to gain additional insights on transfer learning


Transfer Learning


  • A model developed for one task is reused as a starting point for another task

  • Pre-trained models are often used in computer vision & natural language processing tasks

  • Saves compute resources & training time
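The bullet points above in code: the usual recipe is to freeze the pre-trained feature extractor and train only a new classifier head. This is a hedged sketch using a stand-in `nn.Sequential` backbone, not the actual DenseNet used in the demo:

```python
import torch.nn as nn

# Stand-in for a pre-trained backbone (in the demo: DenseNet on ImageNet)
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)

# Freeze pre-trained weights: no gradients, no updates
for param in backbone.parameters():
    param.requires_grad = False

# New head for the target task: 102 flower species
head = nn.Linear(8, 102)
model = nn.Sequential(backbone, head)

# Only the head's parameters will be updated by the optimiser
trainable = [p for p in model.parameters() if p.requires_grad]
```

Because only the small head is trained, the flower classifier converges with far less data and compute than training the whole network from scratch.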

Flower Classifier


From: DenseNet model, pre-trained on ImageNet (1,000 classes)

To: Flower classifier recognising 102 species of flowers, using a dataset from the VGG group.

In [20]:
image = load_image('../examples/images/foxglove.jpg')
input_ = apply_transforms(image)

class_index = 96  # foxglove

pretrained_model = create_model()  # notebook helper: returns a DenseNet pre-trained on ImageNet

backprop = Backprop(pretrained_model)

guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)
guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)

visualize(input_, guided_gradients, guided_max_gradients)
/Users/misao/Projects/personal/flashtorch/flashtorch/saliency/backprop.py:93: UserWarning: The predicted class does not equal the target class. Calculating the gradient with respect to the predicted class.
In [21]:
trained_model = create_model('../models/flower_classification_transfer_learning.pt')

backprop = Backprop(trained_model)

guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)
guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)

visualize(input_, guided_gradients, guided_max_gradients)